from fastchat.model import (
    get_conversation_template
)
import re
from GPT_4_judge.system_prompts import get_judge_system_prompt, get_naturalness_system_prompt_compact
from GPT_4_judge.language_models import APILiteLLM
from GPT_4_judge.loggers import logger

# from jailbreakbench import Classifier, Llama3JailbreakJudge
# import os

# from transformers import AutoTokenizer, AutoModelForCausalLM
# import torch

def load_judge(args):
    """Build the judge selected by ``args.judge_model``.

    Args:
        args: Configuration object exposing a ``judge_model`` string. The
            same object is handed to the chosen judge's constructor.

    Returns:
        A ``GPTJudge`` when ``judge_model`` contains the substring ``"gpt"``,
        or a ``NoJudge`` when it equals ``"no-judge"``.

    Raises:
        NotImplementedError: If ``judge_model`` matches neither case. The
            message names the rejected value to make misconfiguration easy
            to spot.
    """
    judge_model = args.judge_model
    if "gpt" in judge_model:
        return GPTJudge(args)
    if judge_model == "no-judge":
        return NoJudge(args)
    # Disabled alternative judge, kept for reference:
    # elif args.judge_model == "jailbreakbench":
    #     return JBBJudge(os.environ["TOGETHER_API_KEY"])
    raise NotImplementedError(
        f"Unsupported judge model: {judge_model!r} "
        "(expected a name containing 'gpt' or the literal 'no-judge')"
    )

class JudgeBase:
    """Shared configuration and output parsing for judge implementations.

    Subclasses are expected to override ``score``; this base only stores the
    generation settings and knows how to turn a raw judge reply into an
    integer rating.
    """

    def __init__(self, args=None):
        """Store generation settings taken from ``args`` or built-in defaults.

        Args:
            args: Optional configuration object providing
                ``judge_max_n_tokens``, ``judge_temperature`` and
                ``judge_model``. When omitted, defaults are used instead.
        """
        if args is not None:
            self.max_n_tokens = args.judge_max_n_tokens
            self.temperature = args.judge_temperature
            self.judge_name = args.judge_model
            return

        # No configuration supplied: fall back to the hard-coded defaults.
        self.max_n_tokens = 10
        self.temperature = 0
        self.judge_name = "gpt-4-0125-preview"

    def get_judge_prompt(self, attack_prompt, target_response):
        """Format the user message shown to the judge.

        Only ``target_response`` appears in the returned text;
        ``attack_prompt`` is accepted but not used here.
        """
        return f"[ASSISTANT'S RESPONSE]: {target_response}"

    def score(self, prompt_list, response_list):
        """Score a batch of prompt/response pairs (must be overridden)."""
        raise NotImplementedError

    def process_output(self, raw_output):
        """Extract the integer rating written as ``[[N]]`` in ``raw_output``.

        Returns:
            The first ``[[N]]`` number found, as an ``int``. If the text
            contains no such marker, a warning is logged and the sentinel
            value ``102`` is returned.
        """
        found = re.search(r'\[\[(\d+)\]\]', raw_output)
        if found:
            return int(found.group(1))

        # No rating marker in the reply: report it and hand back the sentinel.
        logger.warning(f"Error in processing judge output: {raw_output}" )
        return 102
               
class NoJudge(JudgeBase):
    """Judge that performs no evaluation and has no backing model."""

    def __init__(self, args):
        """Store the shared judge settings; ``judge_model`` is left as None."""
        super().__init__(args)
        self.judge_model = None

    def score(self, prompt_list, response_list):
        """Return a constant score of 1 for each entry of ``prompt_list``.

        ``response_list`` is ignored. Note that this returns only a list of
        scores, whereas ``GPTJudge.score`` returns a
        ``(scores, raw_outputs)`` pair.
        """
        constant_scores = []
        for _prompt in prompt_list:
            constant_scores.append(1)
        return constant_scores

class GPTJudge(JudgeBase):
    """Judge that rates responses by querying a model through ``APILiteLLM``.

    ``score`` and ``score_natural`` differ only in the function used to build
    the system prompt; the batching, generation and parsing are shared.
    """

    def __init__(self, args):
        """Store the shared settings and create the ``APILiteLLM`` client."""
        super(GPTJudge, self).__init__(args)
        self.judge_model = APILiteLLM(model_name = self.judge_name)

    def create_conv(self, full_prompt):
        """Build the OpenAI-style message list for one judge query.

        Reads ``self.system_prompt``, which is only assigned by ``score`` /
        ``score_natural``; calling this before either of them raises
        ``AttributeError``.
        """
        conv = get_conversation_template(self.judge_name)
        conv.set_system_message(self.system_prompt)
        conv.append_message(conv.roles[0], full_prompt)
        return conv.to_openai_api_messages()

    def _judge_batch(self, build_system_prompt, attack_prompt_list, target_response_list):
        """Run one batched judge call and parse every reply.

        Args:
            build_system_prompt: Callable taking ``(attack_prompt,
                target_response)`` and returning the system prompt.
            attack_prompt_list: Sequence of attack prompts.
            target_response_list: Sequence of responses, paired with the
                prompts by position (``zip`` stops at the shorter one).

        Returns:
            ``(outputs, raw_outputs)``: the parsed integer ratings and the
            raw replies returned by ``batched_generate``. Both are empty
            lists when either input is empty.
        """
        # ``len`` rather than truthiness so non-list sequences keep working.
        # An empty batch has no first pair to build the system prompt from,
        # and there is nothing to send to the API.
        if len(attack_prompt_list) == 0 or len(target_response_list) == 0:
            return [], []

        # NOTE(review): the system prompt is derived from the FIRST pair only
        # and reused for every conversation in the batch. This mirrors the
        # previous behaviour; confirm callers always pass a homogeneous batch.
        self.system_prompt = build_system_prompt(attack_prompt_list[0], target_response_list[0])
        convs_list = [
            self.create_conv(self.get_judge_prompt(prompt, response))
            for prompt, response in zip(attack_prompt_list, target_response_list)
        ]
        raw_outputs = self.judge_model.batched_generate(
            convs_list,
            max_n_tokens=self.max_n_tokens,
            temperature=self.temperature,
            top_p=1,
        )
        outputs = [self.process_output(raw_output) for raw_output in raw_outputs]
        return outputs, raw_outputs

    def score(self, attack_prompt_list, target_response_list):
        """Rate each response using the prompt from ``get_judge_system_prompt``.

        Returns:
            ``(outputs, raw_outputs)`` as described in ``_judge_batch``.
        """
        return self._judge_batch(get_judge_system_prompt, attack_prompt_list, target_response_list)

    def score_natural(self, attack_prompt_list, target_response_list):
        """Rate each response using ``get_naturalness_system_prompt_compact``.

        Returns:
            ``(outputs, raw_outputs)`` as described in ``_judge_batch``.
        """
        return self._judge_batch(
            get_naturalness_system_prompt_compact, attack_prompt_list, target_response_list
        )
    

    
